package cn.qianxun.meta.gateway.filter;


import cn.qianxun.meta.common.core.utils.StringUtils;
import com.alibaba.fastjson.JSON;
import cn.qianxun.meta.common.core.domain.Result;
import cn.qianxun.meta.gateway.utils.SqLinjectionRuleUtils;
import cn.qianxun.meta.gateway.utils.XssCleanRuleUtils;
import io.netty.buffer.ByteBufAllocator;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static com.alibaba.cloud.commons.lang.StringUtils.isBlank;

/**
 * Global gateway filter that rejects requests whose parameters contain SQL keywords.
 *
 * <p>GET requests are checked on their XSS-cleaned raw query string and, if accepted, forwarded
 * unchanged. POST/PUT requests whose Content-Type (ignoring parameters such as {@code charset}) is
 * JSON or form-urlencoded have their body XSS-cleaned and checked. If accepted, the cleaned body
 * replaces the original one downstream. Request paths matched by {@link #sqlinjectionHttpUrls}
 * (bound under the {@code sqlxss.ignore} prefix) skip the check entirely.
 *
 * <p>NOTE(review): query strings of non-GET requests and bodies of other methods/content types are
 * not checked — confirm this is intended.
 *
 * @Classname SqLinjectionFilter
 * @Description SQL-injection filter
 * @Date 2021/8/11 17:38
 * @Created by fuzhilin
 */
@Slf4j
@Component
@ConfigurationProperties(prefix = "sqlxss.ignore")
public class SqLinjectionFilter implements GlobalFilter, Ordered {

    /** Content-Type sent downstream after a JSON body has been re-encoded as UTF-8. */
    private static final String JSON_UTF8_CONTENT_TYPE = "application/json;charset=UTF-8";

    /** Factory used to wrap rewritten request bodies into {@link DataBuffer}s. */
    private static final NettyDataBufferFactory BUFFER_FACTORY =
            new NettyDataBufferFactory(ByteBufAllocator.DEFAULT);

    /**
     * Paths that bypass the SQL-injection check. Matching is delegated to
     * {@code StringUtils.matches} (presumably Ant-style patterns — confirm).
     */
    private List<String> sqlinjectionHttpUrls = new ArrayList<>();

    /**
     * Dispatches the request to the query-string or body check, or forwards it unchanged.
     *
     * @param exchange current exchange
     * @param chain    remaining filter chain
     * @return completion of either the downstream chain or a 403 response
     */
    @SneakyThrows
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        log.debug("----自定义防sql注入网关全局过滤器生效----");
        ServerHttpRequest request = exchange.getRequest();
        String path = request.getURI().getRawPath();
        if (StringUtils.matches(path, sqlinjectionHttpUrls)) {
            // A whitelist hit is normal operation, not an error.
            log.debug("请求【{}】在sql注入过滤白名单中，直接放行", path);
            return chain.filter(exchange);
        }

        HttpMethod method = request.getMethod();
        if (method == HttpMethod.GET) {
            return filterQuery(exchange, chain);
        }

        // Compare only type/subtype so "application/json;charset=UTF-8" cannot bypass the check.
        String mediaType = baseMediaType(request.getHeaders().getFirst(HttpHeaders.CONTENT_TYPE));
        boolean json = MediaType.APPLICATION_JSON_VALUE.equalsIgnoreCase(mediaType);
        boolean form = MediaType.APPLICATION_FORM_URLENCODED_VALUE.equalsIgnoreCase(mediaType);
        if ((method == HttpMethod.POST || method == HttpMethod.PUT) && (json || form)) {
            return filterBody(exchange, chain, json);
        }
        return chain.filter(exchange);
    }

    /**
     * Filter execution order: larger values run later, smaller values run earlier.
     */
    @Override
    public int getOrder() {
        return -200;
    }

    /**
     * Checks the raw query string of a GET request. The XSS-cleaned query is only used for the
     * check; the original query is forwarded untouched.
     */
    @SneakyThrows
    private Mono<Void> filterQuery(ServerWebExchange exchange, GatewayFilterChain chain) {
        URI uri = exchange.getRequest().getURI();
        String rawQuery = uri.getRawQuery();
        if (isBlank(rawQuery)) {
            return chain.filter(exchange);
        }
        log.debug("原请求参数为：{}", rawQuery);
        String cleanedQuery = XssCleanRuleUtils.xssGetClean(rawQuery);
        log.debug("修改后参数为：{}", cleanedQuery);

        if (containsSqlKeyWords(cleanedQuery, false)) {
            log.error("请求【{}{}】参数中包含不允许sql的关键词, 请求拒绝", uri.getRawPath(), uri.getRawQuery());
            return setUnauthorizedResponse(exchange);
        }
        return chain.filter(exchange);
    }

    /**
     * Reads, XSS-cleans and checks a JSON or form body. It then either rejects the request or
     * forwards it with the cleaned body.
     *
     * @param json {@code true} for a JSON body, {@code false} for a form-urlencoded body
     */
    private Mono<Void> filterBody(ServerWebExchange exchange, GatewayFilterChain chain, boolean json) {
        ServerHttpRequest request = exchange.getRequest();
        HttpMethod method = request.getMethod();
        URI uri = request.getURI();
        return DataBufferUtils.join(request.getBody())
                .map(SqLinjectionFilter::readAndRelease)
                .defaultIfEmpty("")
                .flatMap(body -> {
                    log.debug("{} - [{}] XSS处理前参数：{}", method, uri.getPath(), body);
                    String cleanedBody = XssCleanRuleUtils.xssPostClean(body);
                    // Bodies may carry credentials: keep them out of INFO logs.
                    log.debug("{} - [{}] XSS处理后参数：{}", method, uri.getPath(), cleanedBody);

                    if (containsSqlKeyWords(cleanedBody, json)) {
                        log.error("{} - [{}] 参数：{}, 包含不允许sql的关键词，请求拒绝", method, uri.getPath(), cleanedBody);
                        return setUnauthorizedResponse(exchange);
                    }
                    ServerHttpRequest rewritten =
                            withBody(request, cleanedBody.getBytes(StandardCharsets.UTF_8), json);
                    return chain.filter(exchange.mutate().request(rewritten).build());
                });
    }

    /**
     * Runs the SQL keyword check appropriate for the parameter encoding.
     *
     * @param params parameters to inspect
     * @param json   {@code true} for a JSON document; otherwise a query-string/form encoding
     * @return {@code true} if the request must be rejected. A decoding failure also counts as a
     *         rejection.
     */
    private boolean containsSqlKeyWords(String params, boolean json) {
        if (json) {
            return SqLinjectionRuleUtils.postRequestSqlKeyWordsCheck(params);
        }
        try {
            return SqLinjectionRuleUtils.getRequestSqlKeyWordsCheck(params);
        } catch (UnsupportedEncodingException e) {
            log.error("参数解码失败，按包含sql关键词处理并拒绝请求", e);
            return true;
        }
    }

    /**
     * Builds a request that exposes {@code body} instead of the original body.
     *
     * <p>Content-Length is recomputed in bytes. Transfer-Encoding is dropped so the two framing
     * headers cannot conflict. JSON bodies are labelled UTF-8. Form bodies keep their original
     * Content-Type so downstream services still parse them as form data.
     */
    private ServerHttpRequest withBody(ServerHttpRequest request, byte[] body, boolean json) {
        HttpHeaders headers = new HttpHeaders();
        headers.putAll(request.getHeaders());
        headers.remove(HttpHeaders.TRANSFER_ENCODING);
        headers.setContentLength(body.length);
        if (json) {
            headers.set(HttpHeaders.CONTENT_TYPE, JSON_UTF8_CONTENT_TYPE);
        }
        return new ServerHttpRequestDecorator(request) {
            @Override
            public Flux<DataBuffer> getBody() {
                // A fresh buffer per subscription: nothing is created if the body is never read,
                // and a re-subscription never sees an already consumed/released buffer.
                return Flux.defer(() -> Flux.just(toDataBuffer(body)));
            }

            @Override
            public HttpHeaders getHeaders() {
                return headers;
            }
        };
    }

    /**
     * Decodes a joined body buffer as UTF-8 and releases it, since join() hands ownership to us.
     */
    private static String readAndRelease(DataBuffer buffer) {
        try {
            byte[] bytes = new byte[buffer.readableByteCount()];
            buffer.read(bytes);
            return new String(bytes, StandardCharsets.UTF_8);
        } finally {
            DataBufferUtils.release(buffer);
        }
    }

    /**
     * Returns the type/subtype of a Content-Type header value, without parameters.
     *
     * @param contentType raw header value, may be {@code null}
     * @return e.g. {@code "application/json"} for {@code "application/json; charset=UTF-8"}, or
     *         {@code ""} when absent
     */
    private static String baseMediaType(String contentType) {
        if (contentType == null) {
            return "";
        }
        int semicolon = contentType.indexOf(';');
        return (semicolon < 0 ? contentType : contentType.substring(0, semicolon)).trim();
    }

    /**
     * Writes a 403 JSON response: SQL keywords were found in the parameters.
     */
    private Mono<Void> setUnauthorizedResponse(ServerWebExchange exchange) {
        return writeForbidden(exchange, "参数中包含不允许sql的关键词, 请求拒绝。");
    }

    /**
     * Writes a 403 JSON response: the request carried no Content-Type.
     */
    private Mono<Void> setUnauthorizedResponseContentType(ServerWebExchange exchange) {
        return writeForbidden(exchange, "请求中没有包含Content-Type, 请求拒绝。");
    }

    /**
     * Completes the exchange with HTTP 403 and a JSON {@code Result.fail(msg)} body.
     */
    private Mono<Void> writeForbidden(ServerWebExchange exchange, String msg) {
        ServerHttpResponse response = exchange.getResponse();
        response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
        response.setStatusCode(HttpStatus.FORBIDDEN);
        return response.writeWith(Mono.fromSupplier(
                () -> response.bufferFactory().wrap(JSON.toJSONBytes(Result.fail(msg)))));
    }

    /**
     * Wraps a byte array into a {@link DataBuffer}.
     *
     * @param bytes byte array
     * @return DataBuffer
     */
    private static DataBuffer toDataBuffer(byte[] bytes) {
        return BUFFER_FACTORY.wrap(bytes);
    }

    public List<String> getSqlinjectionHttpUrls() {
        return sqlinjectionHttpUrls;
    }

    public void setSqlinjectionHttpUrls(List<String> sqlinjectionHttpUrls) {
        this.sqlinjectionHttpUrls = sqlinjectionHttpUrls;
    }
}
